@@ -66,7 +66,7 @@ def __init__(
66
66
eps = 0 ,
67
67
fixed_node_set = None ,
68
68
normalize = True ,
69
- progress = False
69
+ progress = False ,
70
70
):
71
71
self .ts = ts
72
72
self .timepoints = timepoints
@@ -694,13 +694,14 @@ def outside_pass(
694
694
normalize = False ,
695
695
ignore_oldest_root = False ,
696
696
progress = None ,
697
- probability_space_returned = base .LIN
698
697
):
699
698
"""
700
- Computes the full posterior distribution on nodes.
699
+ Computes the full posterior distribution on nodes, returning the
700
+ posterior values. These are *not* probabilities, as they do not sum to one:
701
+ to convert to probabilities, call posterior.to_probabilities()
701
702
702
- Normalising may be necessary if there is overflow, but means that we cannot
703
- check the total functional value at each node
703
+ Normalising *during* the outside process may be necessary if there is overflow,
704
+ but means that we cannot check the total functional value at each node
704
705
705
706
Ignoring the oldest root may also be necessary when the oldest root node
706
707
causes numerical stability issues.
@@ -769,13 +770,11 @@ def outside_pass(
769
770
outside [child ] = self .lik .reduce (val , self .norm [child ])
770
771
if normalize :
771
772
outside [child ] = self .lik .reduce (val , np .max (val ))
773
+ self .outside = outside
772
774
posterior = outside .clone_with_new_data (
773
775
grid_data = self .lik .combine (self .inside .grid_data , outside .grid_data ),
774
776
fixed_data = np .nan ,
775
777
) # We should never use the posterior for a fixed node
776
- posterior .normalize ()
777
- posterior .force_probability_space (probability_space_returned )
778
- self .outside = outside
779
778
return posterior
780
779
781
780
def outside_maximization (self , * , eps , progress = None ):
@@ -857,12 +856,12 @@ def outside_maximization(self, *, eps, progress=None):
857
856
858
857
def posterior_mean_var (ts , posterior , * , fixed_node_set = None ):
859
858
"""
860
- Mean and variance of node age in unscaled time . Fixed nodes will be given a mean
859
+ Mean and variance of node age. Fixed nodes will be given a mean
861
860
of their exact time in the tree sequence, and zero variance (as long as they are
862
- identified by the fixed_node_set
861
+ identified by the fixed_node_set).
863
862
If fixed_node_set is None, we attempt to date all the non-sample nodes
864
- Also assigns the estimated mean and variance of the age of each node, in unscaled
865
- time, as metadata in the tree sequence.
863
+ Also assigns the estimated mean and variance of the age of each node
864
+ as metadata in the tree sequence.
866
865
"""
867
866
mn_post = np .full (ts .num_nodes , np .nan ) # Fill with NaNs so we detect when there's
868
867
vr_post = np .full (ts .num_nodes , np .nan ) # been an error
@@ -936,7 +935,7 @@ def date(
936
935
* ,
937
936
return_posteriors = None ,
938
937
progress = False ,
939
- ** kwargs
938
+ ** kwargs ,
940
939
):
941
940
"""
942
941
Take a tree sequence (which could have
@@ -948,6 +947,19 @@ def date(
948
947
mutations and non-sample nodes in the input tree sequence are not used in inference
949
948
and will be removed.
950
949
950
+ .. note::
951
+ If posteriors are returned via the ``return_posteriors`` option, the output will
952
+ be a tuple ``(ts, posteriors)``, where ``posteriors`` is a dictionary suitable
953
+ for reading as a pandas ``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
954
+ Each node whose time was inferred corresponds to an item in this dictionary,
955
+ with the key being the node ID and the value a 1D array of probabilities of the
956
+ node being in a given time slice (or ``None`` if the "inside_outside" method
957
+ was not used). The start and end times of each time slice are given as 1D
958
+ arrays in the dictionary, under keys named ``"start_time"`` and ``end_time"``.
959
+ As timeslices may not be not of uniform width, it is important to divide the
960
+ posterior probabilities by ``end_time - start_time`` when assessing the shape
961
+ of the probability density function over time.
962
+
951
963
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
952
964
one whose non-sample nodes are undated.
953
965
:param float Ne: The estimated (diploid) effective population size used to construct
@@ -974,9 +986,7 @@ def date(
974
986
conditional coalescent prior with a standard set of time points as given by
975
987
:func:`build_prior_grid`.
976
988
:param bool return_posteriors: If ``True``, instead of returning just a dated tree
977
- sequence, return a tuple of ``(dated_ts, posteriors)``. Note that the dictionary
978
- returned in ``posteriors`` (described below) is suitable for reading as a pandas
979
- ``DataFrame`` object, using ``pd.DataFrame(posteriors)``.
989
+ sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above).
980
990
:param float eps: Specify minimum distance separating time points. Also specifies
981
991
the error factor in time difference calculations. Default: 1e-6
982
992
:param int num_threads: The number of threads to use. A simpler unthreaded algorithm
@@ -996,11 +1006,6 @@ def date(
996
1006
:return: A copy of the input tree sequence but with altered node times, or (if
997
1007
``return_posteriors`` is True) a tuple of that tree sequence plus a dictionary
998
1008
of posterior probabilities from the "inside_outside" estimation ``method``.
999
- Each node whose time was inferred corresponds to an item in this dictionary,
1000
- with the key being the node ID and the value a 1D array of probabilities of the
1001
- node being in a given time slice (or ``None`` if the "inside_outside" method
1002
- was not used). The start and end times of each time slice are given as 1D
1003
- arrays in the dictionary, under keys named ``"start_time"`` and ``end_time"``.
1004
1009
:rtype: tskit.TreeSequence or (tskit.TreeSequence, dict)
1005
1010
"""
1006
1011
if time_units is None :
@@ -1012,7 +1017,7 @@ def date(
1012
1017
recombination_rate = recombination_rate ,
1013
1018
priors = priors ,
1014
1019
progress = progress ,
1015
- ** kwargs
1020
+ ** kwargs ,
1016
1021
)
1017
1022
constrained = constrain_ages_topo (tree_sequence , dates , eps , nds , progress )
1018
1023
tables = tree_sequence .dump_tables ()
@@ -1028,12 +1033,12 @@ def date(
1028
1033
Ne = Ne ,
1029
1034
recombination_rate = recombination_rate ,
1030
1035
progress = progress ,
1031
- ** kwargs
1036
+ ** kwargs ,
1032
1037
)
1033
1038
if return_posteriors :
1034
1039
pst = {"start_time" : timepoints , "end_time" : np .append (timepoints [1 :], np .inf )}
1035
- for i , n in enumerate ( nds ) :
1036
- pst [n ] = None if posteriors is None else posteriors . grid_data [ i , : ]
1040
+ for n in nds :
1041
+ pst [n ] = None if posteriors is None else posteriors [ n ]
1037
1042
return tables .tree_sequence (), pst
1038
1043
else :
1039
1044
return tables .tree_sequence ()
@@ -1053,15 +1058,18 @@ def get_dates(
1053
1058
ignore_oldest_root = False ,
1054
1059
progress = False ,
1055
1060
cache_inside = False ,
1056
- probability_space = base .LOG
1061
+ probability_space = base .LOG ,
1057
1062
):
1058
1063
"""
1059
1064
Infer dates for the nodes in a tree sequence, returning an array of inferred dates
1060
- for nodes, plus other variables such as the distribution of posterior probabilities
1065
+ for nodes, plus other variables such as the posteriors object
1061
1066
etc. Parameters are identical to the date() method, which calls this method, then
1062
1067
injects the resulting date estimates into the tree sequence
1063
1068
1064
- :return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
1069
+ :return: a tuple of ``(mn_post, posteriors, timepoints, eps, nodes_to_date)``.
1070
+ If the "inside_outside" method is used, ``posteriors`` will contain the
1071
+ posterior probabilities for each node in each time slice, else the returned
1072
+ variable will be ``None``.
1065
1073
"""
1066
1074
# Stuff yet to be implemented. These can be deleted once fixed
1067
1075
for sample in tree_sequence .samples ():
@@ -1128,6 +1136,10 @@ def get_dates(
1128
1136
posterior = dynamic_prog .outside_pass (
1129
1137
normalize = outside_normalize , ignore_oldest_root = ignore_oldest_root
1130
1138
)
1139
+ # Turn the posterior into probabilities
1140
+ posterior .normalize () # Just to make sure there are no floating point issues
1141
+ posterior .force_probability_space (base .LIN )
1142
+ posterior .to_probabilities ()
1131
1143
tree_sequence , mn_post , _ = posterior_mean_var (
1132
1144
tree_sequence , posterior , fixed_node_set = fixed_nodes
1133
1145
)
0 commit comments