Skip to content

Commit d56c63d

Browse files
authored
Merge pull request #232 from hyanwong/fix-posteriors
Fix unit test and posteriors
2 parents e9bd098 + 5cc7fad commit d56c63d

File tree

5 files changed

+113
-51
lines changed

5 files changed

+113
-51
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
individuals, populations, or sites, aiming to change the tree sequence tables as
99
little as possible.
1010

11+
**Bugfixes**
12+
13+
- The returned posteriors when ``return_posteriors=True`` now return actual
14+
probabilities (scaled so that they sum to one) rather than normalised
15+
probabilities whose maximum value is one.
16+
1117
--------------------
1218
[0.1.5] - 2022-06-07
1319
--------------------

tests/test_functions.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import utility_functions
3838

3939
import tsdate
40-
from tsdate.base import NodeGridValues
40+
from tsdate import base
4141
from tsdate.core import constrain_ages_topo
4242
from tsdate.core import date
4343
from tsdate.core import get_dates
@@ -797,14 +797,14 @@ def test_init(self):
797797
num_nodes = 5
798798
ids = np.array([3, 4])
799799
timepoints = np.array(range(10))
800-
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
800+
store = base.NodeGridValues(num_nodes, ids, timepoints, fill_value=6)
801801
assert store.grid_data.shape == (len(ids), len(timepoints))
802802
assert len(store.fixed_data) == (num_nodes - len(ids))
803803
assert np.all(store.grid_data == 6)
804804
assert np.all(store.fixed_data == 6)
805805

806806
ids = np.array([3, 4], dtype=np.int32)
807-
store = NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
807+
store = base.NodeGridValues(num_nodes, ids, timepoints, fill_value=5)
808808
assert store.grid_data.shape == (len(ids), len(timepoints))
809809
assert len(store.fixed_data) == num_nodes - len(ids)
810810
assert np.all(store.fixed_data == 5)
@@ -815,7 +815,7 @@ def test_set_and_get(self):
815815
fill = {}
816816
for ids in ([3, 4], []):
817817
np.random.seed(1)
818-
store = NodeGridValues(
818+
store = base.NodeGridValues(
819819
num_nodes, np.array(ids, dtype=np.int32), np.array(range(grid_size))
820820
)
821821
for i in range(num_nodes):
@@ -829,48 +829,52 @@ def test_set_and_get(self):
829829
def test_bad_init(self):
830830
ids = [3, 4]
831831
with pytest.raises(ValueError):
832-
NodeGridValues(3, np.array(ids), np.array([0, 1.2, 2]))
832+
base.NodeGridValues(3, np.array(ids), np.array([0, 1.2, 2]))
833833
with pytest.raises(AttributeError):
834-
NodeGridValues(5, np.array(ids), -1)
834+
base.NodeGridValues(5, np.array(ids), -1)
835835
with pytest.raises(ValueError):
836-
NodeGridValues(5, np.array([-1]), np.array([0, 1.2, 2]))
836+
base.NodeGridValues(5, np.array([-1]), np.array([0, 1.2, 2]))
837837

838838
def test_clone(self):
839839
num_nodes = 10
840840
grid_size = 2
841841
ids = [3, 4]
842-
orig = NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
842+
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array(range(grid_size)))
843843
orig[3] = np.array([1, 2])
844844
orig[4] = np.array([4, 3])
845845
orig[0] = 1.5
846846
orig[9] = 2.5
847847
# test with np.zeros
848-
clone = NodeGridValues.clone_with_new_data(orig, 0)
848+
clone = base.NodeGridValues.clone_with_new_data(orig, 0)
849849
assert clone.grid_data.shape == orig.grid_data.shape
850850
assert clone.fixed_data.shape == orig.fixed_data.shape
851851
assert np.all(clone.grid_data == 0)
852852
assert np.all(clone.fixed_data == 0)
853853
# test with something else
854-
clone = NodeGridValues.clone_with_new_data(orig, 5)
854+
clone = base.NodeGridValues.clone_with_new_data(orig, 5)
855855
assert clone.grid_data.shape == orig.grid_data.shape
856856
assert clone.fixed_data.shape == orig.fixed_data.shape
857857
assert np.all(clone.grid_data == 5)
858858
assert np.all(clone.fixed_data == 5)
859859
# test with different
860860
scalars = np.arange(num_nodes - len(ids))
861-
clone = NodeGridValues.clone_with_new_data(orig, 0, scalars)
861+
clone = base.NodeGridValues.clone_with_new_data(orig, 0, scalars)
862862
assert clone.grid_data.shape == orig.grid_data.shape
863863
assert clone.fixed_data.shape == orig.fixed_data.shape
864864
assert np.all(clone.grid_data == 0)
865865
assert np.all(clone.fixed_data == scalars)
866866

867-
clone = NodeGridValues.clone_with_new_data(orig, np.array([[1, 2], [4, 3]]))
867+
clone = base.NodeGridValues.clone_with_new_data(
868+
orig, np.array([[1, 2], [4, 3]])
869+
)
868870
for i in range(num_nodes):
869871
if i in ids:
870872
assert np.all(clone[i] == orig[i])
871873
else:
872874
assert np.isnan(clone[i])
873-
clone = NodeGridValues.clone_with_new_data(orig, np.array([[1, 2], [4, 3]]), 0)
875+
clone = base.NodeGridValues.clone_with_new_data(
876+
orig, np.array([[1, 2], [4, 3]]), 0
877+
)
874878
for i in range(num_nodes):
875879
if i in ids:
876880
assert np.all(clone[i] == orig[i])
@@ -880,19 +884,44 @@ def test_clone(self):
880884
def test_bad_clone(self):
881885
num_nodes = 10
882886
ids = [3, 4]
883-
orig = NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
887+
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
884888
with pytest.raises(ValueError):
885-
NodeGridValues.clone_with_new_data(
889+
base.NodeGridValues.clone_with_new_data(
886890
orig,
887891
np.array([[1, 2, 3], [4, 5, 6]]),
888892
)
889893
with pytest.raises(ValueError):
890-
NodeGridValues.clone_with_new_data(
894+
base.NodeGridValues.clone_with_new_data(
891895
orig,
892896
0,
893897
np.array([[1, 2], [4, 5]]),
894898
)
895899

900+
def test_convert_to_probs(self):
901+
num_nodes = 10
902+
ids = [3, 4]
903+
make_nan_row = 4
904+
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]), 1)
905+
orig[make_nan_row][0] = np.nan
906+
assert np.all(np.isnan(orig[make_nan_row]) == [True, False])
907+
orig.force_probability_space(base.LIN)
908+
orig.to_probabilities()
909+
for n in orig.nonfixed_nodes:
910+
if n == make_nan_row:
911+
assert np.all(np.isnan(orig[n]))
912+
else:
913+
assert np.allclose(np.sum(orig[n]), 1)
914+
assert np.all(orig[n] >= 0)
915+
916+
def test_cannot_convert_to_probs(self):
917+
# No class implementation of logsumexp to convert to probabilities in log space
918+
num_nodes = 10
919+
ids = [3, 4]
920+
orig = base.NodeGridValues(num_nodes, np.array(ids), np.array([0, 1.2]))
921+
orig.force_probability_space(base.LOG)
922+
with pytest.raises(NotImplementedError, match="linear space"):
923+
orig.to_probabilities()
924+
896925

897926
class TestAlgorithmClass:
898927
def test_nonmatching_prior_vs_lik_timepoints(self):

tests/test_inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# MIT License
22
#
3+
# Copyright (c) 2021-23 Tskit Developers
34
# Copyright (c) 2020 University of Oxford
45
#
56
# Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -109,7 +110,7 @@ def test_no_posteriors(self):
109110
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
110111
assert len(posteriors["start_time"]) > 0
111112
for node in ts.nodes():
112-
if not node.is_sample:
113+
if not node.is_sample():
113114
assert node.id in posteriors
114115
assert posteriors[node.id] is None
115116

@@ -122,7 +123,7 @@ def test_posteriors(self):
122123
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
123124
assert len(posteriors["start_time"]) > 0
124125
for node in ts.nodes():
125-
if not node.is_sample:
126+
if not node.is_sample():
126127
assert node.id in posteriors
127128
assert len(posteriors[node.id]) == len(posteriors["start_time"])
128129
assert np.isclose(np.sum(posteriors[node.id]), 1)

tsdate/base.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@
3838

3939
class NodeGridValues:
4040
"""
41-
A class to store grid values for node ids. For some nodes (fixed ones), only a single
42-
value needs to be stored. For non-fixed nodes, an array of grid_size variables
43-
is required, e.g. in order to store all the possible values for each of the hidden
44-
states in the grid
41+
A class to store times or discretised distributions of times for node ids. For nodes
42+
with fixed times, only a single time value needs to be stored. For non-fixed nodes,
43+
an array of len(timepoints) probabilities is required.
4544
4645
:ivar num_nodes: The number of nodes that will be stored in this object
4746
:vartype num_nodes: int
@@ -130,7 +129,10 @@ def force_probability_space(self, probability_space):
130129

131130
def normalize(self):
132131
"""
133-
normalize grid and fixed data so the max is one
132+
normalize grid data so the max is one (in linear space) or zero
133+
(in logarithmic space)
134+
135+
TODO - is it clear why we omit the first element of each grid row when taking the max?
134136
"""
135137
rowmax = self.grid_data[:, 1:].max(axis=1)
136138
if self.probability_space == LIN:
@@ -140,6 +142,18 @@ def normalize(self):
140142
else:
141143
raise RuntimeError("Probability space is not", LIN, "or", LOG)
142144

145+
def to_probabilities(self):
146+
"""
147+
Change grid data into probabilities (i.e. each row sums to one in linear or zero
148+
in logarithmic space)
149+
"""
150+
if self.probability_space != LIN:
151+
raise NotImplementedError(
152+
"Can only convert to probabilities in linear space"
153+
)
154+
assert not np.any(self.grid_data < 0)
155+
self.grid_data = self.grid_data / self.grid_data.sum(axis=1)[:, np.newaxis]
156+
143157
def __getitem__(self, node_id):
144158
index = self.row_lookup[node_id]
145159
if index < 0:

tsdate/core.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
eps=0,
6767
fixed_node_set=None,
6868
normalize=True,
69-
progress=False
69+
progress=False,
7070
):
7171
self.ts = ts
7272
self.timepoints = timepoints
@@ -694,13 +694,14 @@ def outside_pass(
694694
normalize=False,
695695
ignore_oldest_root=False,
696696
progress=None,
697-
probability_space_returned=base.LIN
698697
):
699698
"""
700-
Computes the full posterior distribution on nodes.
699+
Computes the full posterior distribution on nodes, returning the
700+
posterior values. These are *not* probabilities, as they do not sum to one:
701+
to convert to probabilities, call posterior.to_probabilities()
701702
702-
Normalising may be necessary if there is overflow, but means that we cannot
703-
check the total functional value at each node
703+
Normalising *during* the outside process may be necessary if there is overflow,
704+
but means that we cannot check the total functional value at each node
704705
705706
Ignoring the oldest root may also be necessary when the oldest root node
706707
causes numerical stability issues.
@@ -769,13 +770,11 @@ def outside_pass(
769770
outside[child] = self.lik.reduce(val, self.norm[child])
770771
if normalize:
771772
outside[child] = self.lik.reduce(val, np.max(val))
773+
self.outside = outside
772774
posterior = outside.clone_with_new_data(
773775
grid_data=self.lik.combine(self.inside.grid_data, outside.grid_data),
774776
fixed_data=np.nan,
775777
) # We should never use the posterior for a fixed node
776-
posterior.normalize()
777-
posterior.force_probability_space(probability_space_returned)
778-
self.outside = outside
779778
return posterior
780779

781780
def outside_maximization(self, *, eps, progress=None):
@@ -857,12 +856,12 @@ def outside_maximization(self, *, eps, progress=None):
857856

858857
def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
859858
"""
860-
Mean and variance of node age in unscaled time. Fixed nodes will be given a mean
859+
Mean and variance of node age. Fixed nodes will be given a mean
861860
of their exact time in the tree sequence, and zero variance (as long as they are
862-
identified by the fixed_node_set
861+
identified by the fixed_node_set).
863862
If fixed_node_set is None, we attempt to date all the non-sample nodes
864-
Also assigns the estimated mean and variance of the age of each node, in unscaled
865-
time, as metadata in the tree sequence.
863+
Also assigns the estimated mean and variance of the age of each node
864+
as metadata in the tree sequence.
866865
"""
867866
mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's
868867
vr_post = np.full(ts.num_nodes, np.nan) # been an error
@@ -936,7 +935,7 @@ def date(
936935
*,
937936
return_posteriors=None,
938937
progress=False,
939-
**kwargs
938+
**kwargs,
940939
):
941940
"""
942941
Take a tree sequence (which could have
@@ -948,6 +947,19 @@ def date(
948947
mutations and non-sample nodes in the input tree sequence are not used in inference
949948
and will be removed.
950949
950+
.. note::
951+
If posteriors are returned via the ``return_posteriors`` option, the output will
952+
be a tuple ``(ts, posteriors)``, where ``posteriors`` is a dictionary suitable
953+
for reading as a pandas ``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
954+
Each node whose time was inferred corresponds to an item in this dictionary,
955+
with the key being the node ID and the value a 1D array of probabilities of the
956+
node being in a given time slice (or ``None`` if the "inside_outside" method
957+
was not used). The start and end times of each time slice are given as 1D
958+
arrays in the dictionary, under keys named ``"start_time"`` and ``"end_time"``.
959+
As timeslices may not be of uniform width, it is important to divide the
960+
posterior probabilities by ``end_time - start_time`` when assessing the shape
961+
of the probability density function over time.
962+
951963
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
952964
one whose non-sample nodes are undated.
953965
:param float Ne: The estimated (diploid) effective population size used to construct
@@ -974,9 +986,7 @@ def date(
974986
conditional coalescent prior with a standard set of time points as given by
975987
:func:`build_prior_grid`.
976988
:param bool return_posteriors: If ``True``, instead of returning just a dated tree
977-
sequence, return a tuple of ``(dated_ts, posteriors)``. Note that the dictionary
978-
returned in ``posteriors`` (described below) is suitable for reading as a pandas
979-
``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
989+
sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above).
980990
:param float eps: Specify minimum distance separating time points. Also specifies
981991
the error factor in time difference calculations. Default: 1e-6
982992
:param int num_threads: The number of threads to use. A simpler unthreaded algorithm
@@ -996,11 +1006,6 @@ def date(
9961006
:return: A copy of the input tree sequence but with altered node times, or (if
9971007
``return_posteriors`` is True) a tuple of that tree sequence plus a dictionary
9981008
of posterior probabilities from the "inside_outside" estimation ``method``.
999-
Each node whose time was inferred corresponds to an item in this dictionary,
1000-
with the key being the node ID and the value a 1D array of probabilities of the
1001-
node being in a given time slice (or ``None`` if the "inside_outside" method
1002-
was not used). The start and end times of each time slice are given as 1D
1003-
arrays in the dictionary, under keys named ``"start_time"`` and ``end_time"``.
10041009
:rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
10051010
"""
10061011
if time_units is None:
@@ -1012,7 +1017,7 @@ def date(
10121017
recombination_rate=recombination_rate,
10131018
priors=priors,
10141019
progress=progress,
1015-
**kwargs
1020+
**kwargs,
10161021
)
10171022
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
10181023
tables = tree_sequence.dump_tables()
@@ -1028,12 +1033,12 @@ def date(
10281033
Ne=Ne,
10291034
recombination_rate=recombination_rate,
10301035
progress=progress,
1031-
**kwargs
1036+
**kwargs,
10321037
)
10331038
if return_posteriors:
10341039
pst = {"start_time": timepoints, "end_time": np.append(timepoints[1:], np.inf)}
1035-
for i, n in enumerate(nds):
1036-
pst[n] = None if posteriors is None else posteriors.grid_data[i, :]
1040+
for n in nds:
1041+
pst[n] = None if posteriors is None else posteriors[n]
10371042
return tables.tree_sequence(), pst
10381043
else:
10391044
return tables.tree_sequence()
@@ -1053,15 +1058,18 @@ def get_dates(
10531058
ignore_oldest_root=False,
10541059
progress=False,
10551060
cache_inside=False,
1056-
probability_space=base.LOG
1061+
probability_space=base.LOG,
10571062
):
10581063
"""
10591064
Infer dates for the nodes in a tree sequence, returning an array of inferred dates
1060-
for nodes, plus other variables such as the distribution of posterior probabilities
1065+
for nodes, plus other variables such as the posteriors object
10611066
etc. Parameters are identical to the date() method, which calls this method, then
10621067
injects the resulting date estimates into the tree sequence
10631068
1064-
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
1069+
:return: a tuple of ``(mn_post, posteriors, timepoints, eps, nodes_to_date)``.
1070+
If the "inside_outside" method is used, ``posteriors`` will contain the
1071+
posterior probabilities for each node in each time slice, else the returned
1072+
variable will be ``None``.
10651073
"""
10661074
# Stuff yet to be implemented. These can be deleted once fixed
10671075
for sample in tree_sequence.samples():
@@ -1128,6 +1136,10 @@ def get_dates(
11281136
posterior = dynamic_prog.outside_pass(
11291137
normalize=outside_normalize, ignore_oldest_root=ignore_oldest_root
11301138
)
1139+
# Turn the posterior into probabilities
1140+
posterior.normalize() # Just to make sure there are no floating point issues
1141+
posterior.force_probability_space(base.LIN)
1142+
posterior.to_probabilities()
11311143
tree_sequence, mn_post, _ = posterior_mean_var(
11321144
tree_sequence, posterior, fixed_node_set=fixed_nodes
11331145
)

0 commit comments

Comments
 (0)