@@ -1006,6 +1006,104 @@ def fill_priors(
1006
1006
return prior_times
1007
1007
1008
1008
1009
class MixturePrior:
    """
    Maps conditional coalescent priors (``ConditionalCoalescentTimes``) onto
    nodes in a tree sequence and creates time-discretized priors.

    The constructor computes per-node mixture prior parameters once; the
    result can then be discretized repeatedly, for different population sizes
    and time grids, with :meth:`make_discretized_prior`.
    """

    def __init__(
        self,
        tree_sequence,
        approximate_priors=False,
        approx_prior_size=None,
        prior_distribution="lognorm",
        allow_unary=False,
        progress=False,
    ):
        """
        Compute the mixture prior parameters for every node of a tree sequence.

        :param tree_sequence: The tree sequence whose nodes receive priors. It
            must come back from ``util.reduce_to_contemporaneous`` with an
            unchanged number of nodes.
        :param bool approximate_priors: Forwarded to
            ``ConditionalCoalescentTimes.add``. When true and
            ``approx_prior_size`` is unset, a size of 1000 is used.
        :param int approx_prior_size: First argument given to
            ``ConditionalCoalescentTimes``. May only be set when
            ``approximate_priors`` is true.
        :param string prior_distribution: Name of the prior distribution,
            forwarded to ``ConditionalCoalescentTimes`` and later to
            ``fill_priors``. Default: "lognorm"
        :param bool allow_unary: Forwarded to ``SpansBySamples``.
        :param bool progress: Forwarded to ``SpansBySamples`` and
            ``ConditionalCoalescentTimes``.
        :raises ValueError: If ``approx_prior_size`` is given without
            ``approximate_priors``, or if reducing the tree sequence to
            contemporaneous samples changes its number of nodes.
        """
        if approximate_priors:
            if not approx_prior_size:
                approx_prior_size = 1000
        else:
            if approx_prior_size is not None:
                raise ValueError(
                    "Can't set approx_prior_size if approximate_prior is False"
                )

        contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence)
        if contmpr_ts.num_nodes != tree_sequence.num_nodes:
            raise ValueError(
                "Passed tree sequence is not simplified and/or contains "
                "noncontemporaneous samples"
            )
        span_data = SpansBySamples(
            contmpr_ts, progress=progress, allow_unary=allow_unary
        )

        base_priors = ConditionalCoalescentTimes(
            approx_prior_size, prior_distribution, progress=progress
        )

        base_priors.add(contmpr_ts.num_samples, approximate_priors)
        for total_fixed in span_data.total_fixed_at_0_counts:
            # For missing data: trees vary in total fixed node count => have
            # different priors
            if total_fixed > 0:
                base_priors.add(total_fixed, approximate_priors)
        prior_params_contmpr = base_priors.get_mixture_prior_params(span_data)

        # Map the nodes in the prior params back to the node ids in the
        # original ts
        self.prior_params = prior_params_contmpr[node_map, :]
        self.base_priors = base_priors
        self.tree_sequence = tree_sequence
        self.prior_distribution = prior_distribution

    def make_discretized_prior(self, population_size, timepoints=20, progress=False):
        """
        Calculate prior grid for a set of timepoints and a population size
        history.

        :param population_size: Either a single positive number, or a
            two-dimensional numpy array with two columns holding epoch start
            times and population sizes, respectively. Start times must begin
            at 0 and be strictly increasing; sizes must be positive.
        :param timepoints: Either an integer (at least 2), in which case the
            grid is built by ``create_timepoints`` from ``timepoints + 1``, or
            a numpy array of at least 2 unique, nonnegative values safely
            castable to ``base.FLOAT_DTYPE`` (it is sorted before use).
            Default: 20
        :param bool progress: Forwarded to ``fill_priors``.
        :return: The value returned by ``fill_priors`` for the stored prior
            parameters (presumably a ``base.NodeGridValues`` object; confirm
            against ``fill_priors``).
        :raises ValueError: If ``population_size`` or ``timepoints`` fail the
            checks described above.
        :raises TypeError: If a ``timepoints`` array cannot be safely cast to
            the float dtype.
        """
        if isinstance(population_size, np.ndarray):
            if population_size.ndim != 2:
                raise ValueError("Array 'population_size' must be two-dimensional")
            if population_size.shape[1] != 2:
                # Adjacent literals rather than a backslash continuation, so
                # that source indentation does not leak into the message.
                raise ValueError(
                    "Population size array must have two columns that contain "
                    "epoch start times and population sizes, respectively"
                )
            if np.any(population_size[:, 0] < 0.0):
                raise ValueError("Epoch start times must be nonnegative")
            if np.any(population_size[:, 1] <= 0.0):
                raise ValueError("Population sizes must be positive")
            if population_size[0, 0] != 0:
                raise ValueError("The first epoch must start at time 0")
            if not np.all(np.diff(population_size[:, 0]) > 0):
                raise ValueError("Epoch start times must be unique and increasing")
        else:
            if population_size <= 0:
                raise ValueError("Parameter 'population_size' must be greater than 0")
            # A constant size is represented as a single epoch starting at 0
            population_size = np.array([[0, population_size]], dtype=float)

        # Accept numpy integer scalars (e.g. np.int64) as well as Python ints
        if isinstance(timepoints, (int, np.integer)):
            if timepoints < 2:
                raise ValueError("You must have at least 2 time points")
            timepoints = create_timepoints(self.base_priors, int(timepoints) + 1)
        elif isinstance(timepoints, np.ndarray):
            try:
                timepoints = np.sort(
                    timepoints.astype(base.FLOAT_DTYPE, casting="safe")
                )
            except TypeError as err:
                raise TypeError(
                    "Timepoints array cannot be converted to float dtype"
                ) from err
            if len(timepoints) < 2:
                raise ValueError("You must have at least 2 time points")
            elif np.any(timepoints < 0):
                raise ValueError("Timepoints cannot be negative")
            elif np.any(np.unique(timepoints, return_counts=True)[1] > 1):
                raise ValueError("Timepoints cannot have duplicate values")
        else:
            raise ValueError("time_slices must be an integer or a numpy array of floats")

        # Discretize the stored per-node prior parameters onto the time grid.
        # Original note: "Set all fixed nodes (i.e. samples) to have 0
        # variance" -- presumably handled inside fill_priors; confirm there.
        priors = fill_priors(
            self.prior_params,
            timepoints,
            self.tree_sequence,
            population_size,
            prior_distr=self.prior_distribution,
            progress=progress,
        )
        return priors
1105
+
1106
+
1009
1107
def build_grid (
1010
1108
tree_sequence ,
1011
1109
population_size ,
@@ -1014,7 +1112,6 @@ def build_grid(
1014
1112
approximate_priors = False ,
1015
1113
approx_prior_size = None ,
1016
1114
prior_distribution = "lognorm" ,
1017
- eps = 1e-6 ,
1018
1115
# Parameters below undocumented
1019
1116
progress = False ,
1020
1117
allow_unary = False ,
@@ -1044,87 +1141,13 @@ def build_grid(
1044
1141
better fit, but slightly slower to calculate) or "gamma" for the gamma
1045
1142
distribution (slightly faster, but a poorer fit for recent nodes). Default:
1046
1143
"lognorm"
1047
- :param float eps: Specify minimum distance separating points in the time grid. Also
1048
- specifies the error factor in time difference calculations. Default: 1e-6
1049
1144
:return: A prior object to pass to tsdate.date() containing prior values for
1050
1145
inference and a discretised time grid
1051
1146
:rtype: base.NodeGridValues Object
1052
1147
"""
1053
- if isinstance (population_size , np .ndarray ):
1054
- if population_size .ndim != 2 :
1055
- raise ValueError ("Array 'population_size' must be two-dimensional" )
1056
- if population_size .shape [1 ] != 2 :
1057
- raise ValueError (
1058
- "Population size array must have two columns that contain \
1059
- epoch start times and population sizes, respectively"
1060
- )
1061
- if np .any (population_size [:, 0 ] < 0.0 ):
1062
- raise ValueError ("Epoch start times must be nonnegative" )
1063
- if np .any (population_size [:, 1 ] <= 0.0 ):
1064
- raise ValueError ("Population sizes must be positive" )
1065
- if population_size [0 , 0 ] != 0 :
1066
- raise ValueError ("The first epoch must start at time 0" )
1067
- if not np .all (np .diff (population_size [:, 0 ]) > 0 ):
1068
- raise ValueError ("Epoch start times must be unique and increasing" )
1069
- else :
1070
- if population_size <= 0 :
1071
- raise ValueError ("Parameter 'population_size' must be greater than 0" )
1072
- population_size = np .array ([[0 , population_size ]], dtype = float )
1073
- if approximate_priors :
1074
- if not approx_prior_size :
1075
- approx_prior_size = 1000
1076
- else :
1077
- if approx_prior_size is not None :
1078
- raise ValueError (
1079
- "Can't set approx_prior_size if approximate_prior is False"
1080
- )
1081
1148
1082
- contmpr_ts , node_map = util .reduce_to_contemporaneous (tree_sequence )
1083
- if contmpr_ts .num_nodes != tree_sequence .num_nodes :
1084
- raise ValueError (
1085
- "Passed tree sequence is not simplified and/or contains "
1086
- "noncontemporaneous samples"
1087
- )
1088
- span_data = SpansBySamples (contmpr_ts , progress = progress , allow_unary = allow_unary )
1089
-
1090
- base_priors = ConditionalCoalescentTimes (
1091
- approx_prior_size , prior_distribution , progress = progress
1149
+ mixture_prior = MixturePrior (
1150
+ tree_sequence , approximate_priors , approx_prior_size , prior_distribution , allow_unary , progress
1092
1151
)
1152
+ return mixture_prior .make_discretized_prior (population_size , timepoints )
1093
1153
1094
- base_priors .add (contmpr_ts .num_samples , approximate_priors )
1095
- for total_fixed in span_data .total_fixed_at_0_counts :
1096
- # For missing data: trees vary in total fixed node count => have different priors
1097
- if total_fixed > 0 :
1098
- base_priors .add (total_fixed , approximate_priors )
1099
-
1100
- if isinstance (timepoints , int ):
1101
- if timepoints < 2 :
1102
- raise ValueError ("You must have at least 2 time points" )
1103
- timepoints = create_timepoints (base_priors , timepoints + 1 )
1104
- elif isinstance (timepoints , np .ndarray ):
1105
- try :
1106
- timepoints = np .sort (timepoints .astype (base .FLOAT_DTYPE , casting = "safe" ))
1107
- except TypeError :
1108
- raise TypeError ("Timepoints array cannot be converted to float dtype" )
1109
- if len (timepoints ) < 2 :
1110
- raise ValueError ("You must have at least 2 time points" )
1111
- elif np .any (timepoints < 0 ):
1112
- raise ValueError ("Timepoints cannot be negative" )
1113
- elif np .any (np .unique (timepoints , return_counts = True )[1 ] > 1 ):
1114
- raise ValueError ("Timepoints cannot have duplicate values" )
1115
- else :
1116
- raise ValueError ("time_slices must be an integer or a numpy array of floats" )
1117
-
1118
- prior_params_contmpr = base_priors .get_mixture_prior_params (span_data )
1119
- # Map the nodes in the prior params back to the node ids in the original ts
1120
- prior_params = prior_params_contmpr [node_map , :]
1121
- # Set all fixed nodes (i.e. samples) to have 0 variance
1122
- priors = fill_priors (
1123
- prior_params ,
1124
- timepoints ,
1125
- tree_sequence ,
1126
- population_size ,
1127
- prior_distr = prior_distribution ,
1128
- progress = progress ,
1129
- )
1130
- return priors
0 commit comments