@@ -947,29 +947,55 @@ def gamma_cdf(t_set, alpha, beta):
947
947
return np .insert (t_set , 0 , 0 )
948
948
949
949
950
- def fill_priors (node_parameters , timepoints , ts , Ne , * , prior_distr , progress = False ):
950
+ def fill_priors (
951
+ node_parameters ,
952
+ timepoints ,
953
+ ts ,
954
+ Ne ,
955
+ * ,
956
+ prior_distr ,
957
+ node_var_override = None ,
958
+ progress = False ,
959
+ ):
951
960
"""
952
961
Take the alpha and beta values from the node_parameters array, which contains
953
- one row for each node in the TS (including fixed nodes)
954
- and fill out a NodeGridValues object with the prior values from the
955
- gamma or lognormal distribution with those parameters.
962
+ one row for each node in the TS (including fixed nodes, although alpha and beta
963
+ are ignored for these nodes) and fill out a NodeGridValues object with the prior
964
+ values from the gamma or lognormal distribution with those parameters.
965
+
966
+ For a description of `node_var_override`, see the parameter description in
967
+ the `build_grid` function.
956
968
957
969
TODO - what if there is an internal fixed node? Should we truncate
958
970
"""
959
971
if prior_distr == "lognorm" :
960
972
cdf_func = scipy .stats .lognorm .cdf
961
- main_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
973
+ shape_param = np .sqrt (node_parameters [:, PriorParams .field_index ("beta" )])
962
974
scale_param = np .exp (node_parameters [:, PriorParams .field_index ("alpha" )])
975
+
976
+ def shape_scale_from_mean_var (mean , var ):
977
+ a , b = lognorm_approx (mean , var )
978
+ return np .sqrt (b ), np .exp (a )
979
+
963
980
elif prior_distr == "gamma" :
964
981
cdf_func = scipy .stats .gamma .cdf
965
- main_param = node_parameters [:, PriorParams .field_index ("alpha" )]
966
- scale_param = 1 / node_parameters [:, PriorParams .field_index ("beta" )]
982
+ shape_param = node_parameters [:, PriorParams .field_index ("alpha" )]
983
+ scale_param = 1.0 / node_parameters [:, PriorParams .field_index ("beta" )]
984
+
985
+ def shape_scale_from_mean_var (mean , var ):
986
+ a , b = gamma_approx (mean , var )
987
+ return a , 1.0 / b
988
+
967
989
else :
968
990
raise ValueError ("prior distribution must be lognorm or gamma" )
969
-
991
+ if node_var_override is None :
992
+ node_var_override = {}
970
993
datable_nodes = np .ones (ts .num_nodes , dtype = bool )
971
994
datable_nodes [ts .samples ()] = False
995
+ # Mark all nodes in node_var_override as datable
996
+ datable_nodes [list (node_var_override .keys ())] = True
972
997
datable_nodes = np .where (datable_nodes )[0 ]
998
+
973
999
prior_times = base .NodeGridValues (
974
1000
ts .num_nodes ,
975
1001
datable_nodes [np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
@@ -980,8 +1006,16 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
980
1006
for node in tqdm (
981
1007
datable_nodes , desc = "Assign Prior to Each Node" , disable = not progress
982
1008
):
1009
+ if node in node_var_override :
1010
+ shape , scale = shape_scale_from_mean_var (
1011
+ mean = ts .node (node ).time ,
1012
+ var = node_var_override [node ],
1013
+ )
1014
+ else :
1015
+ shape = shape_param [node ]
1016
+ scale = scale_param [node ]
983
1017
with np .errstate (divide = "ignore" , invalid = "ignore" ):
984
- prior_node = cdf_func (timepoints , main_param [ node ] , scale = scale_param [ node ] )
1018
+ prior_node = cdf_func (timepoints , shape , scale = scale )
985
1019
# force age to be less than max value
986
1020
prior_node = np .divide (prior_node , np .max (prior_node ))
987
1021
# prior in each epoch
@@ -994,7 +1028,7 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
994
1028
def _truncate_priors (ts , priors , progress = False ):
995
1029
"""
996
1030
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
997
- if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1031
+ if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
998
1032
sequence
999
1033
"""
1000
1034
tables = ts .tables
@@ -1060,6 +1094,7 @@ def build_grid(
1060
1094
prior_distribution = "lognorm" ,
1061
1095
allow_historical_samples = None ,
1062
1096
truncate_priors = None ,
1097
+ node_var_override = None ,
1063
1098
eps = 1e-6 ,
1064
1099
progress = False ,
1065
1100
):
@@ -1094,6 +1129,13 @@ def build_grid(
1094
1129
priors of their direct ancestor nodes so that the probability of being younger
1095
1130
than the oldest descendant sample is zero. If the tree sequence is trustworthy
1096
1131
this should give better restults. Default: `True`
1132
+ :param dict node_var_override: is a dict mapping node IDs to a variance value.
1133
+ Any nodes listed here will be treated as non-fixed nodes whose prior is not
1134
+ calculated from the conditional coalescent but instead are allocated a prior
1135
+ whose mean is thenode time in the tree sequence and whose variance is the
1136
+ value in this dictionary. This allows sample nodes to be treated as nonfixed
1137
+ nodes, and therefore dated. If ``None`` (default) then all sample nodes are
1138
+ treated as occurring ata fixed time (as if this were an empty dict).
1097
1139
:param float eps: Specify minimum distance separating points in the time grid. Also
1098
1140
specifies the error factor in time difference calculations. Default: 1e-6
1099
1141
:return: A prior object to pass to tsdate.date() containing prior values for
@@ -1154,16 +1196,18 @@ def build_grid(
1154
1196
tree_sequence ,
1155
1197
Ne ,
1156
1198
prior_distr = prior_distribution ,
1199
+ node_var_override = node_var_override ,
1157
1200
progress = progress ,
1158
1201
)
1159
- if np .any (tree_sequence .tables .nodes .time [tree_sequence .samples ()] != 0 ):
1202
+ tables = tree_sequence .tables
1203
+ if np .any (tables .nodes .time [tree_sequence .samples ()] > 0 ):
1160
1204
if not allow_historical_samples :
1161
1205
raise ValueError (
1162
1206
"There are samples at non-zero times, invalidating the conditional "
1163
1207
"coalescent prior. You can set allow_historical_samples=True to carry "
1164
1208
"on regardless, calculating a prior as if all samples were "
1165
1209
"contemporaneous (reasonable if you only have a few ancient samples)"
1166
1210
)
1167
- if truncate_priors :
1211
+ if np . any ( tables . nodes . time [ priors . fixed_node_ids ()] > 0 ) and truncate_priors :
1168
1212
priors = _truncate_priors (tree_sequence , priors , progress = progress )
1169
1213
return priors
0 commit comments