@@ -72,7 +72,6 @@ class ConditionalCoalescentTimes:
72
72
def __init__ (
73
73
self ,
74
74
precalc_approximation_n ,
75
- population_size ,
76
75
prior_distr = "lognorm" ,
77
76
progress = False ,
78
77
):
@@ -83,7 +82,6 @@ def __init__(
83
82
and therefore do not allow approximate priors to be used
84
83
"""
85
84
self .n_approx = precalc_approximation_n
86
- self .population_size = population_size
87
85
self .prior_store = {}
88
86
self .progress = progress
89
87
@@ -181,8 +179,7 @@ def add(self, total_tips, approximate=None):
181
179
priors [1 ] = PriorParams (alpha = 0 , beta = 1 , mean = 0 , var = 0 )
182
180
for var , tips in zip (variances , all_tips ):
183
181
# NB: it should be possible to vectorize this in numpy
184
- var = var * ((2 * self .Ne ) ** 2 ) # TODO
185
- expectation = self .tau_expect (tips , total_tips ) * 2 * self .Ne # TODO
182
+ expectation = self .tau_expect (tips , total_tips )
186
183
alpha , beta = self .func_approx (expectation , var )
187
184
priors [tips ] = PriorParams (
188
185
alpha = alpha , beta = beta , mean = expectation , var = var
@@ -976,10 +973,15 @@ def fill_priors(
976
973
datable_nodes = np .ones (ts .num_nodes , dtype = bool )
977
974
datable_nodes [ts .samples ()] = False
978
975
datable_nodes = np .where (datable_nodes )[0 ]
976
+
977
+ rescaled_timepoints = util .rescale_time_by_population_size (
978
+ timepoints , population_size
979
+ )
980
+
979
981
prior_times = base .NodeGridValues (
980
982
ts .num_nodes ,
981
983
datable_nodes [np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
982
- timepoints ,
984
+ rescaled_timepoints ,
983
985
)
984
986
985
987
# TO DO - this can probably be done in an single numpy step rather than a for loop
@@ -1020,7 +1022,7 @@ def build_grid(
1020
1022
:param float population_size: The estimated (diploid) effective population
1021
1023
size: must be specified. May be a single value, or a two-column array with
1022
1024
epoch breakpoints and effective population sizes. Using standard (unscaled)
1023
- values for ``population_size`` results in a prior where times are measures
1025
+ values for ``population_size`` results in a prior where times are measured
1024
1026
in generations.
1025
1027
:param int_or_array_like timepoints: The number of quantiles used to create the
1026
1028
time slices, or manually-specified time slices as a numpy array. Default: 20
@@ -1041,9 +1043,34 @@ def build_grid(
1041
1043
inference and a discretised time grid
1042
1044
:rtype: base.NodeGridValues Object
1043
1045
"""
1044
- # TODO
1045
- if population_size <= 0 :
1046
- raise ValueError ("Parameter 'population_size' must be greater than 0" )
1046
+ if isinstance (population_size , np .ndarray ):
1047
+ if population_size .ndim != 2 :
1048
+ raise ValueError (
1049
+ "Parameter 'population_size' must be a scalar or a 2d array"
1050
+ )
1051
+ if population_size .shape [1 ] != 2 :
1052
+ raise ValueError (
1053
+ "'population_size' array must have two columns that contain \
1054
+ epoch start times and population sizes, respectively"
1055
+ )
1056
+ if np .any (population_size [:, 0 ] < 0.0 ):
1057
+ raise ValueError (
1058
+ "Epoch start times in 'population_size' array must be nonnegative"
1059
+ )
1060
+ if np .any (population_size [:, 1 ] <= 0.0 ):
1061
+ raise ValueError (
1062
+ "Population sizes in 'population_size' array must be positive "
1063
+ )
1064
+ if population_size [0 , 0 ] != 0 :
1065
+ raise ValueError (
1066
+ "The first epoch in 'population_size' array must start at time 0"
1067
+ )
1068
+ if not np .all (np .diff (population_size [:, 0 ]) > 0 ):
1069
+ raise ValueError (
1070
+ "Epoch start times 'population_size' array must be unique and increasing"
1071
+ )
1072
+ elif population_size <= 0 :
1073
+ raise ValueError ("Scalar 'population_size' must be greater than 0" )
1047
1074
if approximate_priors :
1048
1075
if not approx_prior_size :
1049
1076
approx_prior_size = 1000
@@ -1062,7 +1089,7 @@ def build_grid(
1062
1089
span_data = SpansBySamples (contmpr_ts , progress = progress , allow_unary = allow_unary )
1063
1090
1064
1091
base_priors = ConditionalCoalescentTimes (
1065
- approx_prior_size , population_size , prior_distribution , progress = progress
1092
+ approx_prior_size , prior_distribution , progress = progress
1066
1093
)
1067
1094
1068
1095
base_priors .add (contmpr_ts .num_samples , approximate_priors )
0 commit comments