@@ -952,29 +952,55 @@ def gamma_cdf(t_set, alpha, beta):
952
952
return np .insert (t_set , 0 , 0 )
953
953
954
954
955
- def fill_priors (node_parameters , timepoints , ts , Ne , * , prior_distr , progress = False ):
955
+ def fill_priors (
956
+ node_parameters ,
957
+ timepoints ,
958
+ ts ,
959
+ Ne ,
960
+ * ,
961
+ prior_distr ,
962
+ node_var_override = None ,
963
+ progress = False ,
964
+ ):
956
965
"""
957
966
Take the alpha and beta values from the node_parameters array, which contains
958
- one row for each node in the TS (including fixed nodes)
959
- and fill out a NodeGridValues object with the prior values from the
960
- gamma or lognormal distribution with those parameters.
967
+ one row for each node in the TS (including fixed nodes, although alpha and beta
968
+ are ignored for these nodes) and fill out a NodeGridValues object with the prior
969
+ values from the gamma or lognormal distribution with those parameters.
970
+
971
+ For a description of `node_var_override`, see the parameter description in
972
+ the `build_grid` function.
961
973
962
974
TODO - what if there is an internal fixed node? Should we truncate
963
975
"""
964
976
if prior_distr == "lognorm" :
965
977
cdf_func = scipy .stats .lognorm .cdf
966
- main_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
978
+ shape_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
967
979
scale_param = np .exp (node_parameters [:, PriorParams .field_index ("alpha" )])
980
+
981
+ def shape_scale_from_mean_var (mean , var ):
982
+ a , b = lognorm_approx (mean , var )
983
+ return np .sqrt (b ), np .exp (a )
984
+
968
985
elif prior_distr == "gamma" :
969
986
cdf_func = scipy .stats .gamma .cdf
970
- main_param = node_parameters [:, PriorParams .field_index ("alpha" )]
971
- scale_param = 1 / node_parameters [:, PriorParams .field_index ("beta" )]
987
+ shape_param = node_parameters [:, PriorParams .field_index ("alpha" )]
988
+ scale_param = 1.0 / node_parameters [:, PriorParams .field_index ("beta" )]
989
+
990
+ def shape_scale_from_mean_var (mean , var ):
991
+ a , b = gamma_approx (mean , var )
992
+ return a , 1.0 / b
993
+
972
994
else :
973
995
raise ValueError ("prior distribution must be lognorm or gamma" )
974
-
996
+ if node_var_override is None :
997
+ node_var_override = {}
975
998
datable_nodes = np .ones (ts .num_nodes , dtype = bool )
976
999
datable_nodes [ts .samples ()] = False
1000
+ # Mark all nodes in node_var_override as datable
1001
+ datable_nodes [list (node_var_override .keys ())] = True
977
1002
datable_nodes = np .where (datable_nodes )[0 ]
1003
+
978
1004
prior_times = base .NodeGridValues (
979
1005
ts .num_nodes ,
980
1006
datable_nodes [np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
@@ -985,8 +1011,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
985
1011
for node in tqdm (
986
1012
datable_nodes , desc = "Assign Prior to Each Node" , disable = not progress
987
1013
):
1014
+ if node in node_var_override :
1015
+ shape , scale = shape_scale_from_mean_var (
1016
+ mean = ts .node (node ).time ,
1017
+ var = node_var_override [node ],
1018
+ )
1019
+ else :
1020
+ shape = shape_param [node ]
1021
+ scale = scale_param [node ]
988
1022
with np .errstate (divide = "ignore" , invalid = "ignore" ):
989
- prior_node = cdf_func (timepoints , main_param [ node ] , scale = scale_param [ node ] )
1023
+ prior_node = cdf_func (timepoints , shape , scale = scale )
990
1024
# force age to be less than max value
991
1025
prior_node = np .divide (prior_node , np .max (prior_node ))
992
1026
# prior in each epoch
@@ -999,7 +1033,7 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
999
1033
def _truncate_priors (ts , priors , progress = False ):
1000
1034
"""
1001
1035
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1002
- if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1036
+ if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1003
1037
sequence
1004
1038
"""
1005
1039
tables = ts .tables
@@ -1065,6 +1099,7 @@ def build_grid(
1065
1099
prior_distribution = "lognorm" ,
1066
1100
allow_historical_samples = None ,
1067
1101
truncate_priors = None ,
1102
+ node_var_override = None ,
1068
1103
eps = 1e-6 ,
1069
1104
# Parameters below undocumented
1070
1105
progress = False ,
@@ -1100,6 +1135,13 @@ def build_grid(
1100
1135
priors of their direct ancestor nodes so that the probability of being younger
1101
1136
than the oldest descendant sample is zero. If the tree sequence is trustworthy
1102
1137
this should give better restults. Default: `True`
1138
+ :param dict node_var_override: is a dict mapping node IDs to a variance value.
1139
+ Any nodes listed here will be treated as non-fixed nodes whose prior is not
1140
+ calculated from the conditional coalescent but instead are allocated a prior
1141
+ whose mean is thenode time in the tree sequence and whose variance is the
1142
+ value in this dictionary. This allows sample nodes to be treated as nonfixed
1143
+ nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1144
+ treated as occurring ata fixed time (as if this were an empty dict).
1103
1145
:param float eps: Specify minimum distance separating points in the time grid. Also
1104
1146
specifies the error factor in time difference calculations. Default: 1e-6
1105
1147
:return: A prior object to pass to tsdate.date() containing prior values for
@@ -1160,16 +1202,18 @@ def build_grid(
1160
1202
tree_sequence ,
1161
1203
Ne ,
1162
1204
prior_distr = prior_distribution ,
1205
+ node_var_override = node_var_override ,
1163
1206
progress = progress ,
1164
1207
)
1165
- if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1208
+ tables = tree_sequence .tables
1209
+ if np .any (tables .nodes .time [tree_sequence .samples ()] > 0 ):
1166
1210
if not allow_historical_samples :
1167
1211
raise ValueError (
1168
1212
"There are samples at non-zero times, invalidating the conditional "
1169
1213
"coalescent prior. You can set allow_historical_samples=True to carry "
1170
1214
"on regardless, calculating a prior as if all samples were "
1171
1215
"contemporaneous (reasonable if you only have a few ancient samples)"
1172
1216
)
1173
- if truncate_priors :
1217
+ if np . any ( tables . nodes . time [ priors . fixed_node_ids ()] > 0 ) and truncate_priors :
1174
1218
priors = _truncate_priors (tree_sequence , priors , progress = progress )
1175
1219
return priors
0 commit comments